"""Utilities for serializing arbitrary objects."""
import copy
from typing import Any, Callable, Mapping, Type, TypeVar

import spacy
import torch
from torch import nn

from src.utils.typing import PathLike

# Type aliases for subclass overrides. The first three are the same type (a
# string-keyed mapping); the separate names only document which part of the
# serialization protocol a given mapping belongs to.
_StrKeyed = Mapping[str, Any]
Properties = _StrKeyed  # Constructor kwargs, as returned by `properties`.
Serialized = _StrKeyed  # Payload produced by `serialize`.
Children = _StrKeyed  # Child property name -> type key, from `serializable`.
Resolved = Mapping[str, Type['Serializable']]  # Child name -> its class.

# Lets classmethods such as `deserialize` be typed as returning the subclass.
SerializableT = TypeVar('SerializableT', bound='Serializable')


def _map_leaves(value: Any, transform: Callable[[Any], Any]) -> Any:
    """Return `value` with `transform` applied to every non-dict leaf.

    Dicts (including dict subclasses) are walked recursively; any other value
    is a leaf, so lists, tuples, etc. are not entered. The input is never
    mutated: a dict is shallow-copied with `copy.copy` (which keeps its type)
    only if at least one of its entries changed, and dicts whose subtree is
    untouched are returned as the very same object.
    """
    if not isinstance(value, dict):
        return transform(value)
    updated = None
    for key, item in value.items():
        new_item = _map_leaves(item, transform)
        if new_item is not item:
            if updated is None:
                updated = copy.copy(value)
            updated[key] = new_item
    return value if updated is None else updated


def _encode_known_types(value: Any) -> Any:
    """Encode a single leaf value that needs special serialization.

    spacy `Language` objects become a `(config, bytes)` pair; every other
    value is returned unchanged (the same object).
    """
    if isinstance(value, spacy.Language):
        return (value.config, value.to_bytes())
    return value


def _decode_known_types(value: Any) -> Any:
    """Invert `_encode_known_types` for a single leaf value.

    Every other value is returned unchanged (the same object).
    """
    # NOTE(review): this is a structural heuristic -- any 2-tuple of
    # (dict, bytes) is treated as an encoded spacy pipeline.
    if isinstance(value, tuple) and len(value) == 2:
        config, payload = value
        if isinstance(config, dict) and isinstance(payload, bytes):
            lang = spacy.util.get_lang_class(config['nlp']['lang'])
            nlp = lang.from_config(config)
            nlp.from_bytes(payload)
            return nlp
    return value


class Serializable:
    """Mixin that makes your class (loosely) serializable.

    Here, "serialization" means transforming an object to an equivalent
    python dictionary. The original object should be reconstructable
    from that dictionary alone; however, the serialization procedure need
    not be the same as doing `vars(obj)` even though that is the default
    behavior for this mixin.

    Objects are deserialized by passing the entire dictionary of properties
    as keyword arguments to the class constructor. Subclasses should override
    the `properties` method so that the returned dictionary exactly reflects
    the kwargs expected by the constructor.

    Another feature: if any value returned by `properties` is itself
    serializable, this mixin will recursively serialize it! To deserialize
    such values, override the @classmethod `resolve` so that it maps each
    recursively serializable property name to its python type. If a property
    can hold one of several types, also override the `serializable` method to
    return a dictionary mapping that property name to a key (e.g. a string)
    encoding its type. That dictionary is stored in the payload and handed
    back to `resolve` at load time.

    Finally, note this mixin will also *automatically* serialize some known
    types, e.g. spacy `Language` objects. You don't need to do anything
    special for this to happen. The mixin finds them among the property
    values and inside (arbitrarily nested) dicts of them, then handles them
    appropriately. Values inside lists, tuples, or other containers are not
    searched.

    The use case for this mixin is to get the general serialization behavior of
    pkl without tying the payloads to this codebase. Indeed, you can open the
    payloads generated by `Serializable` anywhere, in any codebase, because
    they're just dictionaries! This is especially useful for distributing
    research models, which are frequently shared across many different
    codebases in settings where the source is heavily, frequently, and
    indiscriminately modified.
    """

    def __init__(self, **_: Any):
        """Initialize the object (necessary for type checking)."""
        super().__init__()

    def properties(self) -> Properties:
        """Return all properties needed to reconstruct the object.

        Properties are basically constructor (kw)args. By default, returns
        dictionary of object fields. Subclasses should override this method
        if different behavior is desired.
        """
        return vars(self)

    def serializable(self) -> Children:
        """Return type information for recursively serializable fields.

        It is not strictly required for subclasses to override this method.
        Its only use is when the recursively serialized field could be one of
        several types. In that case, this function should return a dictionary
        mapping the field name to some value (e.g a string) encoding the type
        information. Otherwise, subclasses can just override `resolve` and
        return the single type that field can take.
        """
        return {}

    def serialize(self, **kwargs: Any) -> Serialized:
        """Return object serialized as a dictionary.

        Keyword arguments are not used by this method itself. They are
        forwarded to the `serialize` call of every Serializable property, so
        subclasses can define their own serialization options.

        The object itself is not modified, including any dicts nested in its
        properties.

        Raises:
            ValueError: If a key returned by `serializable` names a property
                whose value is neither Serializable nor None.

        Returns:
            Serialized: A dict whose 'properties' entry holds the serialized
                properties and whose 'children' entry holds the output of
                `serializable`.

        """
        # Encode known special types (spacy pipelines) into plain data.
        # `_map_leaves` copies every dict it changes, so nested dicts that
        # belong to `self` keep their original values.
        properties = _map_leaves(dict(self.properties()), _encode_known_types)

        # Handle recursive serialization. A declared child may be None,
        # mirroring `deserialize`, which skips None children.
        children = self.serializable()
        for key, value in list(properties.items()):
            if isinstance(value, Serializable):
                properties[key] = value.serialize(**kwargs)
            elif key in children and value is not None:
                raise ValueError(f'child "{key}" is not serializable '
                                 f'type: {type(value).__name__}')

        return {'properties': properties, 'children': children}

    @classmethod
    def deserialize(
        cls: Type[SerializableT],
        serialized: Mapping[str, Any],
        **kwargs: Any,
    ) -> SerializableT:
        """Deserialize the object from its properties.

        Keyword arguments are passed to recursive `deserialize` calls.

        Args:
            serialized (Mapping[str, Any]): The serialized object, as returned
                by `serialize`. It is not modified, including nested dicts.

        Returns:
            SerializableT: The deserialized object.

        """
        properties = dict(serialized['properties'])
        children = dict(serialized['children'])

        # First decode known special types. `_map_leaves` copies every dict
        # it changes, so the caller's payload is left untouched.
        properties = _map_leaves(properties, _decode_known_types)

        # Then handle recursion, if requested and if possible. Missing or
        # None children are passed to the constructor as they are.
        resolved = cls.resolve(children)
        for key, child_type in resolved.items():
            if properties.get(key) is not None:
                properties[key] = child_type.deserialize(
                    properties[key], **kwargs)

        return cls(**properties)

    @classmethod
    def resolve(cls, children: Children) -> Resolved:
        """Resolve Serializable types for all children.

        Receives the 'children' mapping stored by `serialize` and must return
        a mapping from property name to the Serializable subclass used to
        deserialize that property. The default resolves nothing.
        """
        return {}


# Lets SerializableModule classmethods (`deserialize`, `load`) be typed as
# returning the concrete subclass they are called on.
SerializableModuleT = TypeVar(
    'SerializableModuleT', bound='SerializableModule')


class SerializableModule(Serializable, nn.Module):
    """A serializable torch module.

    This class provides `save` and `load` functions in addition to the
    `serialize` and `deserialize` functions provided by `Serializable`.
    By default, `serialize` also stores the module's parameters under a
    top-level 'state_dict' key of the payload. `deserialize` passes that
    entry to `load_state_dict`.
    """

    def __init__(self, **_: Any):
        """Initialize the module."""
        super().__init__()

    def serialize(self, state_dict: bool = True, **kwargs: Any) -> Serialized:
        """Serialize the module, including its state dict.

        Keyword arguments are forwarded to `Serializable.serialize`.
        Serializable children are serialized with `state_dict=False`, so
        their parameters are only saved if they are part of this module's
        own state dict (e.g. because they are registered submodules).

        Args:
            state_dict (bool, optional): Include the module's parameters in
                the payload under the 'state_dict' key. Defaults to True.

        Returns:
            Serialized: The serialized module.

        """
        serialized = dict(super().serialize(state_dict=False, **kwargs))
        if state_dict:
            serialized['state_dict'] = self.state_dict()
        return serialized

    def save(self, file: PathLike, **kwargs: Any) -> None:
        """Save the module to the given file so it can be reconstructed.

        This function calls `SerializableModule.serialize` and writes its
        output to `file` with `torch.save`. Model parameters are included
        unless `state_dict=False` is passed.

        Keyword arguments are forwarded to `serialize`.

        Args:
            file (PathLike): File to save model info in.

        """
        torch.save(self.serialize(**kwargs), file)

    @classmethod
    def deserialize(cls: Type[SerializableModuleT],
                    serialized: Mapping[str, Any],
                    strict: bool = False,
                    load_state_dict: bool = True,
                    **kwargs: Any) -> SerializableModuleT:
        """Instantiate the module from its properties.

        Keyword arguments are forwarded to `Serializable.deserialize`.

        Args:
            serialized (Mapping[str, Any]): Serialized module. If it has a
                top-level 'state_dict' entry, this method will call
                `torch.nn.Module.load_state_dict` on it. The mapping itself
                is not modified.
            strict (bool, optional): Forwarded to
                `torch.nn.Module.load_state_dict`. Defaults to False.
            load_state_dict (bool, optional): If set, load the serialized
                module parameters. Defaults to True.

        Returns:
            SerializableModuleT: Instantiated module.

        """
        payload = dict(serialized)  # Copied because we pop from it below.
        state_dict = payload.pop('state_dict', None)
        # Children are built without loading their own state dicts; the
        # parent's state dict (loaded below) is expected to cover them.
        module = super().deserialize(payload, load_state_dict=False, **kwargs)
        if load_state_dict and state_dict is not None:
            module.load_state_dict(state_dict, strict=strict)
        return module

    @classmethod
    def load(cls: Type[SerializableModuleT],
             file: PathLike,
             *,
             strict: bool = False,
             load_state_dict: bool = True,
             **kwargs: Any) -> SerializableModuleT:
        """Load the module from the given path.

        Keyword arguments other than `strict` and `load_state_dict` are
        forwarded to `torch.load`.

        Args:
            file (PathLike): File to load module from.
            strict (bool, optional): Forwarded to `deserialize`.
                Defaults to False.
            load_state_dict (bool, optional): Forwarded to `deserialize`.
                Defaults to True.

        Returns:
            SerializableModuleT: Loaded module.

        """
        # SECURITY: `torch.load` unpickles the file; only load trusted files.
        # NOTE(review): depending on the torch version, `torch.load` may
        # default to `weights_only=True`, which can reject payloads holding
        # non-tensor objects (e.g. spacy configs). Confirm, and pass
        # `weights_only` explicitly via kwargs if needed.
        payload = torch.load(file, **kwargs)
        return cls.deserialize(payload,
                               strict=strict,
                               load_state_dict=load_state_dict)
